"""
Adapted from Jonathan Dursi
https://github.com/ljdursi/poapy
"""

import collections
import json
import textwrap
from typing import Dict, List, Optional, Union

import numpy

from .alignment import SeqGraphAlignment


class Node(object):
    """A graph node holding one sequence element and its incident edges.

    ``inEdges`` maps a predecessor node ID to the Edge coming from it,
    ``outEdges`` maps a successor node ID to the Edge leading to it, and
    ``alignedTo`` lists the IDs of other nodes aligned to this one.
    """

    def __init__(self, nodeID: int = -1, text: str = ""):
        self.ID = nodeID
        self.text = text
        self.inEdges = {}  # predecessor ID -> Edge
        self.outEdges = {}  # successor ID -> Edge
        self.alignedTo = []

    def __str__(self):
        return "(%d:%s)" % (self.ID, self.text)

    def _add_edge(
        self,
        edgeset: Dict[int, "Edge"],
        neighbourID: int,
        label: Optional[Union[int, List[int]]],
        from_neighbour: bool,
        weight: int = 1,
    ):
        """Create the edge to ``neighbourID`` in ``edgeset``, or update it if present.

        Args:
            edgeset: ``self.inEdges`` or ``self.outEdges``.
            neighbourID: ID of the node at the other end; ``None`` is a no-op.
            label: A single label, a list of labels, or ``None`` for no label.
            from_neighbour: True when the edge runs neighbour -> self (an
                in-edge), False when it runs self -> neighbour (an out-edge).
            weight: Amount added to the edge weight.
        """
        if neighbourID is None:
            return

        if neighbourID not in edgeset:
            # Create an appropriately-ordered edge.
            if from_neighbour:
                edge = Edge(outNodeID=neighbourID, inNodeID=self.ID, label=label, weight=weight)
            else:
                edge = Edge(outNodeID=self.ID, inNodeID=neighbourID, label=label, weight=weight)
            edgeset[neighbourID] = edge
            return

        # Already present: accumulate weight and merge labels.
        edge = edgeset[neighbourID]
        edge.weight += weight
        if isinstance(label, list):
            new_labels = label
        elif label is None:
            # Same treatment as a new edge created with label=None: no label added.
            new_labels = []
        else:
            new_labels = [label]
        # Build a fresh list (no in-place mutation of a possibly shared list) and
        # drop duplicates while keeping a deterministic first-seen order.
        edge.labels = list(dict.fromkeys(list(edge.labels) + list(new_labels)))

    def addInEdge(self, neighbourID: int, label: Optional[Union[int, List[int]]], weight: int = 1):
        """Add (or reinforce) the edge neighbourID -> self."""
        self._add_edge(self.inEdges, neighbourID, label, from_neighbour=True, weight=weight)

    def addOutEdge(self, neighbourID: int, label: Optional[Union[int, List[int]]], weight: int = 1):
        """Add (or reinforce) the edge self -> neighbourID."""
        self._add_edge(self.outEdges, neighbourID, label, from_neighbour=False, weight=weight)

    def nextNode(self, label: int):
        """Return the ID of the first (presumably only) out-neighbour whose
        connecting edge carries ``label``, or ``None`` if there is none."""
        for neighbourID, edge in self.outEdges.items():
            if label in edge.labels:
                return neighbourID
        return None

    @property
    def inDegree(self):
        """Number of distinct predecessors."""
        return len(self.inEdges)

    @property
    def outDegree(self):
        """Number of distinct successors."""
        return len(self.outEdges)

    @property
    def weightedInDegree(self):
        """Sum of the weights of all in-edges."""
        return sum(edge.weight for edge in self.inEdges.values())

    @property
    def weightedOutDegree(self):
        """Sum of the weights of all out-edges."""
        return sum(edge.weight for edge in self.outEdges.values())

    @property
    def labels(self):
        """Returns all the labels associated with an in-edge or an out edge,
        without duplicates, in first-seen order (in-edges first)."""
        ordered = {}
        for edges in (self.inEdges, self.outEdges):
            for edge in edges.values():
                ordered.update(dict.fromkeys(edge.labels))
        return list(ordered)


class Edge(object):
    """A weighted, labelled directed edge.

    The edge leaves node ``outNodeID`` and enters node ``inNodeID``; this is
    how ``Node._add_edge`` constructs it (``outNodeID`` is the source).
    ``labels`` lists the labels of the sequences that traverse the edge.
    """

    def __init__(
        self,
        inNodeID: int = -1,
        outNodeID: int = -1,
        label: Optional[Union[int, List[int]]] = None,
        weight: int = 1,
    ):
        self.inNodeID = inNodeID  # destination node
        self.outNodeID = outNodeID  # source node

        self.weight = weight

        if label is None:
            self.labels = []
        elif isinstance(label, list):
            # Copy: storing the caller's list would share it between this edge,
            # any other edge built from the same list, and the caller.
            self.labels = list(label)
        else:
            self.labels = [label]

    def addLabel(self, newlabel):
        """Append ``newlabel`` to the edge's labels (no deduplication)."""
        self.labels.append(newlabel)

    def __str__(self):
        # Print in the direction of travel: source -> destination.
        nodestr = "(%d) -> (%d) " % (self.outNodeID, self.inNodeID)
        if self.labels is None:
            return nodestr
        return nodestr + str(self.labels)


class POAGraph(object):
    """Partial-order alignment (POA) graph over a collection of sequences.

    Every sequence element is stored in a Node; consecutive elements of one
    sequence are connected by an Edge carrying that sequence's label and an
    accumulated weight. Nodes with different text that were aligned to one
    another are linked through their ``alignedTo`` lists.
    """

    def __init__(self, seq=None, label: Optional[Union[int, List[int]]] = None):
        """Create an empty graph, optionally seeded with ``seq`` under ``label``."""
        self._nextnodeID = 0
        self._nnodes = 0
        self._nedges = 0
        self.nodedict = {}  # node ID -> Node
        self.nodeidlist = []  # allows a (partial) order to be imposed on the nodes
        # True when nodeidlist may no longer be topologically sorted. Every
        # method reads/writes this one attribute (exposed as ``needsSort``).
        self._needsort = False
        self._labels = []
        self._seqs = []
        self._starts = []
        # label -> IDs of the nodes on the aligned part of each sequence added
        # via incorporateSeqAlignment.
        self._seq_paths = {}

        if seq is not None:
            self.addUnmatchedSeq(seq, label)

    def addUnmatchedSeq(self, seq, label: int = -1, updateSequences=True):
        """Add a completely independent (sub)string to the graph as a new chain.

        Args:
            seq: Iterable of elements; each element becomes one node's text.
            label: Label attached to the edges of the new chain.
            updateSequences: When True, also record ``seq``, ``label`` and the
                chain's first node as one of the graph's sequences.

        Returns:
            ``(firstID, lastID)`` of the new chain (``(None, None)`` for an empty
            ``seq``), or ``None`` when ``seq`` is None.
        """
        if seq is None:
            return

        firstID, lastID = None, None
        neededSort = self.needsSort

        for text in seq:
            nodeID = self.addNode(text)
            if firstID is None:
                firstID = nodeID
            if lastID is not None:
                self.addEdge(lastID, nodeID, label)
            lastID = nodeID

        # The chain is appended to nodeidlist in order and only links its own
        # nodes, so no new order problems are introduced.
        self._needsort = neededSort
        if updateSequences:
            self._seqs.append(seq)
            self._labels.append(label)
            self._starts.append(firstID)
        return firstID, lastID

    def nodeIdxToBase(self, idx):
        """Return the text of the node at position ``idx`` of ``nodeidlist``."""
        return self.nodedict[self.nodeidlist[idx]].text

    def addNode(self, text):
        """Create a node holding ``text``; return its new ID."""
        nid = self._nextnodeID
        newnode = Node(nid, text)
        self.nodedict[nid] = newnode
        self.nodeidlist.append(nid)
        self._nnodes += 1
        self._nextnodeID += 1
        self._needsort = True
        return nid

    def addEdge(self, start, end, label, weight: int = 1):
        """Add (or reinforce) the edge ``start -> end`` with ``label``.

        A ``None`` endpoint makes this a no-op.

        Raises:
            KeyError: if either endpoint is not a node of the graph.
        """
        if start is None or end is None:
            return

        if start not in self.nodedict:
            raise KeyError("addEdge: Start node not in graph: " + str(start))
        if end not in self.nodedict:
            raise KeyError("addEdge: End node not in graph: " + str(end))

        oldNodeEdges = self.nodedict[start].outDegree + self.nodedict[end].inDegree

        self.nodedict[start].addOutEdge(end, label, weight)
        self.nodedict[end].addInEdge(start, label, weight)

        newNodeEdges = self.nodedict[start].outDegree + self.nodedict[end].inDegree

        # Only a previously absent edge changes the degrees.
        if newNodeEdges != oldNodeEdges:
            self._nedges += 1

        self._needsort = True

    @property
    def needsSort(self):
        """Whether ``nodeidlist`` must be re-sorted before use."""
        return self._needsort

    @property
    def nNodes(self):
        return self._nnodes

    @property
    def nEdges(self):
        return self._nedges

    @property
    def num_sequences(self):
        return len(self._seqs)

    def get_sequences(self):
        return self._seqs

    def _simplified_graph_rep(self):
        """Collapse each group of mutually aligned nodes into one pseudonode.

        Returns:
            A list of ``Pseudonode`` namedtuples indexed by ``pnode_id``, holding
            the pseudonode IDs of predecessors and successors (repeats possible)
            and the IDs of the graph nodes merged into the pseudonode.
        """
        node_to_pn = {}
        pn_to_nodes = {}

        # Find the mappings from nodes to pseudonodes
        cur_pnid = 0
        for node in self.nodedict.values():
            if node.ID not in node_to_pn:
                node_ids = [node.ID] + node.alignedTo
                pn_to_nodes[cur_pnid] = node_ids
                for nid in node_ids:
                    node_to_pn[nid] = cur_pnid
                cur_pnid += 1

        # create the pseudonodes
        Pseudonode = collections.namedtuple(
            "Pseudonode", ["pnode_id", "predecessors", "successors", "node_ids"]
        )
        pseudonodes = []

        for pnid in range(cur_pnid):
            nids, preds, succs = pn_to_nodes[pnid], [], []
            for nid in nids:
                node = self.nodedict[nid]
                preds += [node_to_pn[inEdge.outNodeID] for inEdge in node.inEdges.values()]
                succs += [node_to_pn[outEdge.inNodeID] for outEdge in node.outEdges.values()]

            pn = Pseudonode(pnode_id=pnid, predecessors=preds, successors=succs, node_ids=nids)
            pseudonodes.append(pn)

        return pseudonodes

    def toposort(self):
        """Sorts node list so that all incoming edges come from nodes earlier in the list."""
        #
        # The topological sort of this graph is complicated by the alignedTo edges;
        # we want nodes connected by such edges to remain near each other in the
        # topological sort.
        #
        # Here we'll create a simple version of the graph that merges nodes that
        # are alignedTo each other, performs the sort, and then decomposes the
        # 'pseudonodes'.
        #
        # The need for this suggests that the way the graph is currently represented
        # isn't quite right and needs some rethinking.
        #
        pseudonodes = self._simplified_graph_rep()
        completed = set()
        # DFS post-order (reverse topological order). Appending and reversing
        # once at the end avoids the O(n) cost of repeated list.insert(0, ...).
        reverse_order = []

        def dfs(start):
            stack, started = [start], set()
            while stack:
                pnodeID = stack.pop()

                if pnodeID in completed:
                    continue

                if pnodeID in started:
                    # All successors are finished: emit this pseudonode.
                    completed.add(pnodeID)
                    reverse_order.extend(pseudonodes[pnodeID].node_ids)
                    started.remove(pnodeID)
                    continue

                started.add(pnodeID)
                # Revisit this pseudonode after its successors are done.
                stack.append(pnodeID)
                stack.extend(pseudonodes[pnodeID].successors)

        while len(reverse_order) < self.nNodes:
            found = None
            for pnid, pnode in enumerate(pseudonodes):
                if pnid not in completed and not pnode.predecessors:
                    found = pnid
                    break
            assert found is not None
            dfs(found)

        assert len(reverse_order) == self.nNodes
        self.nodeidlist = reverse_order[::-1]
        self._needsort = False

    def testsort(self):
        """Test the nodeidlist to make sure it is topologically sorted:
        eg, all predecessors of a node precede the node in the list"""
        if self.nodeidlist is None:
            return
        seen_nodes = set()
        for nodeidx in self.nodeidlist:
            node = self.nodedict[nodeidx]
            for in_neighbour in node.inEdges:
                assert in_neighbour in seen_nodes
            seen_nodes.add(nodeidx)

    def nodeiterator(self):
        """Return a zero-argument generator function yielding nodes in sorted
        order, re-sorting first if needed."""
        if self.needsSort:
            self.toposort()

        def nodegenerator():
            for nodeidx in self.nodeidlist:
                yield self.nodedict[nodeidx]

        return nodegenerator

    def __str__(self):
        parts = []
        ni = self.nodeiterator()
        for node in ni():
            parts.append(str(node) + "\n")
            for edge in node.outEdges.values():
                parts.append("        " + str(edge) + "\n")
        return "".join(parts)

    def incorporateSeqAlignment(self, alignment: SeqGraphAlignment, seq, label: int = -1):
        """Incorporate a SeqGraphAlignment into the graph.

        Args:
            alignment: Provides ``sequence`` (the elements to thread through the
                graph) and the parallel lists ``stringidxs``/``nodeidxs``, which
                pair an index into ``sequence`` (or None) with a graph node ID
                (or None).
            seq: Object recorded in the graph's sequence list for this entry.
            label: Label attached to every edge the sequence traverses.
        """
        newseq = alignment.sequence
        stringidxs = alignment.stringidxs
        nodeidxs = alignment.nodeidxs

        firstID = None
        headID = None
        tailID = None

        path = []
        # head, tail of sequence may be unaligned; just add those into the
        # graph directly
        validstringidxs = [si for si in stringidxs if si is not None]
        startSeqIdx, endSeqIdx = validstringidxs[0], validstringidxs[-1]
        if startSeqIdx > 0:
            firstID, headID = self.addUnmatchedSeq(
                newseq[0:startSeqIdx], label, updateSequences=False
            )
        if endSeqIdx < len(newseq) - 1:
            tailID, __ = self.addUnmatchedSeq(newseq[endSeqIdx + 1 :], label, updateSequences=False)

        # now we march along the aligned part. For each text, we find or create
        # a node in the graph:
        #   - if unmatched, the corresponding node is a new node
        #   - if matched:
        #       - if matched to a node with the same text, the node is that node
        #       - if matched to a node with a different text which is in turn
        #         aligned to a node with the same text, that aligned node is
        #         the node
        #       - otherwise, we create a new node.
        # In all cases, we create edges (or add labels) threading through the
        # nodes.
        for sindex, matchID in zip(stringidxs, nodeidxs):
            if sindex is None:
                continue
            text = newseq[sindex]
            if matchID is None:
                nodeID = self.addNode(text)
            elif self.nodedict[matchID].text == text:
                nodeID = matchID
            else:
                otherAligns = self.nodedict[matchID].alignedTo
                foundNode = None
                for otherNodeID in otherAligns:
                    if self.nodedict[otherNodeID].text == text:
                        foundNode = otherNodeID
                if foundNode is None:
                    nodeID = self.addNode(text)
                    # Keep alignment groups as cliques: the new node is aligned
                    # to every member of the group, and vice versa.
                    self.nodedict[nodeID].alignedTo = [matchID] + otherAligns
                    for otherNodeID in [matchID] + otherAligns:
                        self.nodedict[otherNodeID].alignedTo.append(nodeID)
                else:
                    nodeID = foundNode

            self.addEdge(headID, nodeID, label)
            headID = nodeID
            if firstID is None:
                firstID = headID

            path.append(nodeID)

        # finished the unaligned portion: now add an edge from the current headID to the tailID.
        self.addEdge(headID, tailID, label)

        # resort
        self.toposort()

        self._seqs.append(seq)
        self._labels.append(label)
        self._starts.append(firstID)
        self._seq_paths[label] = path

    @staticmethod
    def _effective_weight(edge, excludeLabels):
        """Weight of ``edge`` once the labels in ``excludeLabels`` are discounted.

        Per-label weights are not stored, so the weight is scaled by the
        fraction of the edge's labels that are not excluded. With nothing to
        exclude, the stored weight is returned unchanged.
        """
        if not excludeLabels or not edge.labels:
            return edge.weight
        kept = sum(1 for lbl in edge.labels if lbl not in excludeLabels)
        if kept == len(edge.labels):
            return edge.weight
        return edge.weight * kept / len(edge.labels)

    def consensus(self, excludeLabels=None):
        """Return the consensus path through the graph.

        Nodes are visited in reverse topological order. Each node keeps the
        out-edge with the largest ``(edge weight, neighbour score, neighbour ID)``
        tuple, and its score is that weight plus the neighbour's score. The path
        starts at the highest-scoring node and follows the chosen edges.

        Args:
            excludeLabels: Labels whose contribution to edge weights is
                discounted (see ``_effective_weight``).

        Returns:
            ``(path, bases, labels)``: the node IDs on the path, their texts and
            their label lists. The final node of the path (the END node) is
            dropped. All three are empty for an empty graph.
        """
        if excludeLabels is None:
            excludeLabels = []

        if self.needsSort:
            self.toposort()

        if not self.nodeidlist:
            return [], [], []

        nodesInReverse = self.nodeidlist[::-1]
        maxnodeID = max(nodesInReverse) + 1
        nextInPath = [-1] * maxnodeID
        scores = numpy.zeros((maxnodeID))

        for nodeID in nodesInReverse:
            bestWeightScoreEdge = (-1, -1, None)
            for neighbourID, e in self.nodedict[nodeID].outEdges.items():
                weight = self._effective_weight(e, excludeLabels)
                weightScoreEdge = (weight, scores[neighbourID], neighbourID)

                if weightScoreEdge > bestWeightScoreEdge:
                    bestWeightScoreEdge = weightScoreEdge

            scores[nodeID] = sum(bestWeightScoreEdge[0:2])
            nextInPath[nodeID] = bestWeightScoreEdge[2]

        pos = int(numpy.argmax(scores))
        path = []
        bases = []
        labels = []
        while pos is not None and pos > -1:
            path.append(pos)
            bases.append(self.nodedict[pos].text)
            labels.append(self.nodedict[pos].labels)
            pos = nextInPath[pos]

        # ignore END node
        path = path[:-1]
        bases = bases[:-1]
        labels = labels[:-1]
        return path, bases, labels

    def allConsenses(self, maxfraction=0.5):
        """Compute successive consensus paths, excluding well-covered sequences.

        After each consensus, a sequence's label is excluded when the consensus
        passes through at least ``maxfraction * len(seq)`` nodes carrying that
        label. The next consensus is then computed with those labels
        discounted. Iteration stops when every label is excluded, when a pass
        excludes nothing new (the next consensus would be identical), when the
        last path is shorter than 10 nodes, or after 10 passes.

        Returns:
            A list of ``(path, bases, labels)`` tuples as returned by consensus.
        """
        allpaths = []
        allbases = []
        alllabels = []
        exclusions = []

        passno = 0
        lastlen = 1000
        maxpasses = 10

        while (
            any(label not in exclusions for label in self._labels)
            and lastlen >= 10
            and passno < maxpasses
        ):
            path, bases, labellists = self.consensus(exclusions)
            newlyExcluded = False
            if len(path) > 0:
                allpaths.append(path)
                allbases.append(bases)
                alllabels.append(labellists)

                labelcounts = collections.defaultdict(int)
                for ll in labellists:
                    for label in ll:
                        labelcounts[label] += 1

                for label, seq in zip(self._labels, self._seqs):
                    if label in exclusions:
                        continue
                    if label in labelcounts and labelcounts[label] >= maxfraction * len(seq):
                        exclusions.append(label)
                        newlyExcluded = True

            lastlen = len(path)
            passno += 1
            if not newlyExcluded:
                break

        return list(zip(allpaths, allbases, alllabels))

    def generateAlignmentStrings(self):
        """Return a list of strings corresponding to the alignments in the graph"""

        # Step 1: assign node IDs to columns in the output
        #  column_index[node.ID] is the position in the toposorted node list
        #    of the node itself, or the earliest node it is aligned to.
        column_index = {}
        current_column = 0

        # go through nodes in toposort order
        ni = self.nodeiterator()
        for node in ni():
            other_columns = [
                column_index[other] for other in node.alignedTo if other in column_index
            ]
            if other_columns:
                found_idx = min(other_columns)
            else:
                found_idx = current_column
                current_column += 1

            column_index[node.ID] = found_idx

        ncolumns = current_column

        # Step 2: given the column indexes, populate the strings
        #   corresponding to the sequences inserted in the graph
        seqnames = []
        alignstrings = []
        for label, start in zip(self._labels, self._starts):
            seqnames.append(label)
            curnode_id = start
            charlist = ["-"] * ncolumns
            while curnode_id is not None:
                node = self.nodedict[curnode_id]
                charlist[column_index[curnode_id]] = node.text
                curnode_id = node.nextNode(label)
            alignstrings.append("".join(charlist))

        # Step 3: Same as step 2, but with consensus sequences
        consenses = self.allConsenses()
        for i, consensus in enumerate(consenses):
            seqnames.append("Consensus" + str(i))
            charlist = ["-"] * ncolumns
            for path, text in zip(consensus[0], consensus[1]):
                charlist[column_index[path]] = text
            alignstrings.append("".join(charlist))

        return list(zip(seqnames, alignstrings))

    @staticmethod
    def _js_string(text):
        """Render ``text`` as a double-quoted JavaScript string literal.

        ``json.dumps`` escapes quotes, backslashes and control characters.
        ``</`` is additionally escaped so that text such as ``</script>`` cannot
        terminate the enclosing HTML ``<script>`` element.
        """
        return json.dumps(str(text), ensure_ascii=False).replace("</", "<\\/")

    @staticmethod
    def _join_js_entries(entries):
        """Suffix every entry except the last with a comma (JS array items)."""
        return [entry + "," for entry in entries[:-1]] + entries[-1:]

    def jsOutput(self, verbose: bool = False, annotate_consensus: bool = True):
        """returns a list of strings containing a description of the graph for viz.js, http://visjs.org

        Args:
            verbose: Also emit dashed red edges between aligned nodes.
            annotate_consensus: Pin consensus-path nodes horizontally (150 units
                apart) and colour them green.
        """
        # get the consensus sequence, which we'll use as the "spine" of the
        # graph: node ID -> x coordinate
        pathdict = {}
        if annotate_consensus:
            path, __, __ = self.consensus()
            for i, nodeID in enumerate(path):
                pathdict[nodeID] = i * 150

        node_entries = []
        ni = self.nodeiterator()
        for node in ni():
            line = (
                "    {id:"
                + str(node.ID)
                + ", label: "
                + self._js_string(str(node.ID) + ": " + str(node.text))
            )
            if node.ID in pathdict:
                line += (
                    ", x: "
                    + str(pathdict[node.ID])
                    + ", y: 0 , fixed: { x:true, y:false},"
                    + "color: '#7BE141', is_consensus:true}"
                )
            else:
                line += "}"
            node_entries.append(line)

        lines = ["var nodes = ["]
        lines.extend(self._join_js_entries(node_entries))
        lines.append("];")

        lines.append(" ")

        edge_entries = []
        ni = self.nodeiterator()
        for node in ni():
            nodeID = str(node.ID)
            for edge in node.outEdges:
                target = str(edge)
                weight = str(len(node.outEdges[edge].labels) + 1.5)
                edge_entries.append(
                    "    {from: "
                    + nodeID
                    + ", to: "
                    + target
                    + ", value: "
                    + weight
                    + ", color: '#4b72b0', arrows: 'to'}"
                )
            if verbose:
                for alignededge in node.alignedTo:
                    # These edges indicate alignment to different bases, and are
                    # undirected; thus make sure we only plot them once:
                    if node.ID > alignededge:
                        continue
                    target = str(alignededge)
                    edge_entries.append(
                        "    {from: "
                        + nodeID
                        + ", to: "
                        + target
                        + ', value: 1, style: "dash-line", color: "red"}'
                    )

        lines.append("var edges = [")
        lines.extend(self._join_js_entries(edge_entries))
        lines.append("];")
        return lines

    def htmlOutput(self, outfile, verbose: bool = False, annotate_consensus: bool = True):
        """Write a standalone vis-network HTML page visualising the graph.

        Args:
            outfile: Writable text stream (only ``write`` is used).
            verbose, annotate_consensus: Forwarded to jsOutput.
        """
        header = """
                  <!doctype html>
                  <html>
                  <head>
                    <title>POA Graph Alignment</title>

                    <script type="text/javascript" src="https://unpkg.com/vis-network@9.0.4/standalone/umd/vis-network.min.js"></script>
                  </head>

                  <body>

                  <div id="loadingProgress">0%</div>

                  <div id="mynetwork"></div>

                  <script type="text/javascript">
                    // create a network
                  """
        outfile.write(textwrap.dedent(header[1:]))
        lines = self.jsOutput(verbose=verbose, annotate_consensus=annotate_consensus)
        for line in lines:
            outfile.write(line + "\n")
        # NOTE(review): the 'beforeDrawing' handler below tests node.isConsensus,
        # while jsOutput emits is_consensus, so the handler never fires. It is left
        # unchanged because enabling it would update nodes inside a draw callback.
        footer = """
                  var container = document.getElementById('mynetwork');
                  var data= {
                    nodes: nodes,
                    edges: edges,
                  };
                  var options = {
                    width: '100%',
                    height: '800px',
                    physics: {
                        enabled: false,
                        stabilization: {
                            updateInterval: 10,
                        },
                        hierarchicalRepulsion: {
                            avoidOverlap: 0.9,
                        },
                    },
                    edges: {
                        color: {
                            inherit: false
                        }
                    },
                    layout: {
                        hierarchical: {
                            direction: "UD",
                            sortMethod: "directed",
                            shakeTowards: "roots",
                            levelSeparation: 150, // Adjust as needed
                            nodeSpacing: 100, // Adjust as needed
                            treeSpacing: 200, // Adjust as needed
                            parentCentralization: true,
                        }
                    }
                  };
                  var network = new vis.Network(container, data, options);

                  network.on('beforeDrawing', function(ctx) {
                    nodes.forEach(function(node) {
                        if (node.isConsensus) {
                            // Set the level of spine nodes to the bottom
                            network.body.data.nodes.update({
                                id: node.id,
                                level: 0 // Set level to 0 for spine nodes
                            });
                        }
                    });
                });

                  network.on("stabilizationProgress", function (params) {
                    document.getElementById("loadingProgress").innerText = Math.round(params.iterations / params.total * 100) + "%";
                  });
                  network.once("stabilizationIterationsDone", function () {
                      document.getElementById("loadingProgress").innerText = "100%";
                      setTimeout(function () {
                        document.getElementById("loadingProgress").style.display = "none";
                      }, 500);
                  });
                </script>

                </body>
                </html>
                """
        outfile.write(textwrap.dedent(footer))
